{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Recurrent Neural Networks with Word Embeddings\n", "\n", "Based on:\n", "\n", "http://deeplearning.net/tutorial/rnnslu.html\n", "\n", "This notebook takes the material from that page, downloads it, and converts it to run under Python 3." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we delete any leftover files from a previous run:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "!rm -rf is13 atis.pkl*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we get the example code:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'is13'...\n", "remote: Counting objects: 71, done.\u001b[K\n", "remote: Total 71 (delta 0), reused 0 (delta 0), pack-reused 71\u001b[K\n", "Unpacking objects: 100% (71/71), done.\n", "Checking connectivity... 
done.\n" ] } ], "source": [ "!git clone https://github.com/mesnilgr/is13.git" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and get the sample data:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " % Total % Received % Xferd Average Speed Time Time Time Current\n", " Dload Upload Total Spent Left Speed\n", "100 214k 100 214k 0 0 1669k 0 --:--:-- --:--:-- --:--:-- 1673k\n" ] } ], "source": [ "!curl -o atis.pkl.gz http://www-etud.iro.umontreal.ca/~mesnilgr/atis/atis.pkl.gz" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!gunzip atis.pkl.gz" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we convert the example Python2 code to Python3:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "RefactoringTool: Skipping implicit fixer: buffer\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from 
/usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "RefactoringTool: Skipping implicit fixer: idioms\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from 
/usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "RefactoringTool: Skipping implicit fixer: set_literal\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "RefactoringTool: Skipping implicit fixer: ws_comma\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from 
/usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "root: Generating grammar tables from /usr/lib/python3.4/lib2to3/PatternGrammar.txt\n", "RefactoringTool: Refactored is13/data/load.py\n", "RefactoringTool: Refactored is13/examples/elman-forward.py\n", "RefactoringTool: Refactored is13/examples/jordan-forward.py\n", "RefactoringTool: Refactored is13/metrics/accuracy.py\n", "RefactoringTool: No changes to is13/rnn/elman.py\n", "RefactoringTool: No changes to is13/rnn/jordan.py\n", "RefactoringTool: Refactored is13/utils/tools.py\n", "RefactoringTool: Files that were modified:\n", "RefactoringTool: is13/data/load.py\n", "RefactoringTool: is13/examples/elman-forward.py\n", "RefactoringTool: is13/examples/jordan-forward.py\n", "RefactoringTool: is13/metrics/accuracy.py\n", "RefactoringTool: is13/rnn/elman.py\n", "RefactoringTool: is13/rnn/jordan.py\n", "RefactoringTool: is13/utils/tools.py\n" ] } ], "source": [ "!2to3-3.4 -w is13" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!sed -i \"s/load(f)/load(f, encoding=\\\"latin1\\\")/\" is13/data/load.py" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!sed -i \"s/\\//\\/\\//g\" is13/utils/tools.py" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!sed -i \"s/open(filename).read()/bytes(open(filename).read(), \\\"utf-8\\\")/\" is13/metrics/accuracy.py" ] }, { "cell_type": "code", 
"execution_count": 8, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!sed -i \"s/stdout.split/str(stdout, \\\"utf-8\\\").split/\" is13/metrics/accuracy.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example\n", "\n", "We import the libraries:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy\n", "import time\n", "import sys\n", "import subprocess\n", "import os\n", "import random\n", "\n", "from is13.data import load\n", "from is13.rnn.elman import model\n", "from is13.metrics.accuracy import conlleval\n", "from is13.utils.tools import shuffle, minibatch, contextwin\n", "from functools import reduce" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and run an experiment:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch... 0\n", " NEW BEST: epoch 0 valid F1 86.65 best test F1 83.13 \n", "Epoch... 1\n", " NEW BEST: epoch 1 valid F1 89.99 best test F1 87.25 \n", "Epoch... 2\n", " NEW BEST: epoch 2 valid F1 92.8 best test F1 89.52 \n", "Epoch... 
3\n", " " ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 52\u001b[0m \u001b[0mlabels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_y\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 53\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mword_batch\u001b[0m \u001b[1;33m,\u001b[0m \u001b[0mlabel_last_word\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwords\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 54\u001b[1;33m \u001b[0mrnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mword_batch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel_last_word\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0ms\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'clr'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 55\u001b[0m \u001b[0mrnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnormalize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 56\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0ms\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'verbose'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m/usr/local/lib/python3.4/dist-packages/theano/compile/function_module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 593\u001b[0m \u001b[0mt0_fn\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 594\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 595\u001b[1;33m \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 596\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 597\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'position_of_error'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m/usr/local/lib/python3.4/dist-packages/theano/scan_module/scan_op.py\u001b[0m in \u001b[0;36mrval\u001b[1;34m(p, i, o, n, allow_gc)\u001b[0m\n\u001b[0;32m 670\u001b[0m def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,\n\u001b[0;32m 671\u001b[0m allow_gc=allow_gc):\n\u001b[1;32m--> 672\u001b[1;33m \u001b[0mr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mo\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 673\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mnode\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 674\u001b[0m 
\u001b[0mcompute_map\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mo\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m/usr/local/lib/python3.4/dist-packages/theano/scan_module/scan_op.py\u001b[0m in \u001b[0;36m\u001b[1;34m(node, args, outs)\u001b[0m\n\u001b[0;32m 659\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 660\u001b[0m \u001b[0mouts\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 661\u001b[1;33m self, node)\n\u001b[0m\u001b[0;32m 662\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mImportError\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtheano\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgof\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcmodule\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMissingGXX\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 663\u001b[0m \u001b[0mp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexecute\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mscan_perform.pyx\u001b[0m in \u001b[0;36mtheano.scan_module.scan_perform.perform (/home/dblank/.theano/compiledir_Linux-3.13--generic-x86_64-with-Ubuntu-14.04-trusty-x86_64-3.4.0-64/scan_perform/mod.cpp:3537)\u001b[1;34m()\u001b[0m\n", "\u001b[1;32m/usr/local/lib/python3.4/dist-packages/theano/gof/op.py\u001b[0m in \u001b[0;36mrval\u001b[1;34m(p, i, o, n)\u001b[0m\n\u001b[0;32m 766\u001b[0m \u001b[1;31m# default arguments are stored in the closure of `rval`\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 767\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mrval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mp\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mp\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mi\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mnode_input_storage\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mo\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mnode_output_storage\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 768\u001b[1;33m \u001b[0mr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mo\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 769\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mnode\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 770\u001b[0m \u001b[0mcompute_map\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mo\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m/usr/local/lib/python3.4/dist-packages/theano/tensor/blas.py\u001b[0m in \u001b[0;36mperform\u001b[1;34m(self, node, inputs, out_storage)\u001b[0m\n\u001b[0;32m 396\u001b[0m \u001b[1;31m# overwrite_y=self.inplace)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 397\u001b[0m out_storage[0][0] = gemv(alpha, A.T, x, beta, y,\n\u001b[1;32m--> 398\u001b[1;33m overwrite_y=self.inplace, trans=True)\n\u001b[0m\u001b[0;32m 399\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 400\u001b[0m \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mnumpy\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mA\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [
 "# Settings for the run. The training loop below also stores its running state in this dict:\n",
 "# 'clr' (current learning rate), 'ce' (current epoch), 'be' (best epoch) and the best valid/test scores.\n",
 "s = {'fold':3, # 5 folds 0,1,2,3,4\n",
 "     'lr':0.0627142536696559,\n",
 "     'verbose':1,\n",
 "     'decay':False, # decay on the learning rate if improvement stops\n",
 "     'win':7, # number of words in the context window\n",
 "     'bs':9, # number of backprop through time steps\n",
 "     'nhidden':100, # number of hidden units\n",
 "     'seed':345,\n",
 "     'emb_dimension':100, # dimension of word embedding\n",
 "     'nepochs':50}\n",
 "\n",
 "# directory that receives the saved model and the conlleval input/output files\n",
 "folder = \"is13/examples\"\n",
 "os.makedirs(folder, exist_ok=True)\n",
 "\n",
 "# load the dataset\n",
 "train_set, valid_set, test_set, dic = load.atisfold(s['fold'])\n",
 "idx2label = dict((k,v) for v,k in dic['labels2idx'].items())\n",
 "idx2word = dict((k,v) for v,k in dic['words2idx'].items())\n",
 "\n",
 "train_lex, train_ne, train_y = train_set\n",
 "valid_lex, valid_ne, valid_y = valid_set\n",
 "test_lex, test_ne, test_y = test_set\n",
 "\n",
 "# number of distinct word indices / label indices over all three splits\n",
 "# (a set comprehension avoids the quadratic list concatenation of reduce(list+list))\n",
 "vocsize = len({idx for sentence in train_lex + valid_lex + test_lex for idx in sentence})\n",
 "nclasses = len({idx for sentence in train_y + test_y + valid_y for idx in sentence})\n",
 "\n",
 "nsentences = len(train_lex)\n",
 "\n",
 "# instantiate the model\n",
 "numpy.random.seed(s['seed'])\n",
 "random.seed(s['seed'])\n",
 "rnn = model(nh = s['nhidden'],\n",
 "            nc = nclasses,\n",
 "            ne = vocsize,\n",
 "            de = s['emb_dimension'],\n",
 "            cs = s['win'])\n",
 "\n",
 "# train with early stopping on validation set\n",
 "best_f1 = -numpy.inf\n",
 "s['clr'] = s['lr']\n",
 "for e in range(s['nepochs']):\n",
 "    # shuffle\n",
 "    print(\"Epoch...\", e)\n",
 "    sys.stdout.flush()\n",
 "    shuffle([train_lex, train_ne, train_y], s['seed'])\n",
 "    s['ce'] = e\n",
 "    tic = time.time()\n",
 "    for i in range(nsentences):\n",
 "        cwords = contextwin(train_lex[i], s['win'])\n",
 "        words = [numpy.asarray(x).astype('int32') for x in minibatch(cwords, s['bs'])]\n",
 "        labels = train_y[i]\n",
 "        # one update per (minibatch, label) pair, followed by rnn.normalize()\n",
 "        for word_batch, label_last_word in zip(words, labels):\n",
 "            rnn.train(word_batch, label_last_word, s['clr'])\n",
 "            rnn.normalize()\n",
 "        if s['verbose']:\n",
 "            print('[learning] epoch %i >> %2.2f%%'%(e,(i+1)*100./nsentences),'completed in %.2f (sec) <<\\r'%(time.time()-tic), end=' ')\n",
 "            sys.stdout.flush()\n",
 "\n",
 "    # evaluation // back into the real world : idx -> words\n",
 "    predictions_test = [[idx2label[x] for x in rnn.classify(numpy.asarray(contextwin(x, s['win'])).astype('int32'))]\n",
 "                        for x in test_lex]\n",
 "    groundtruth_test = [[idx2label[x] for x in y] for y in test_y]\n",
 "    words_test = [[idx2word[x] for x in w] for w in test_lex]\n",
 "\n",
 "    predictions_valid = [[idx2label[x] for x in rnn.classify(numpy.asarray(contextwin(x, s['win'])).astype('int32'))]\n",
 "                         for x in valid_lex]\n",
 "    groundtruth_valid = [[idx2label[x] for x in y] for y in valid_y]\n",
 "    words_valid = [[idx2word[x] for x in w] for w in valid_lex]\n",
 "\n",
 "    # evaluation // compute the accuracy using conlleval.pl\n",
 "    res_test = conlleval(predictions_test, groundtruth_test, words_test, folder + '/current.test.txt')\n",
 "    res_valid = conlleval(predictions_valid, groundtruth_valid, words_valid, folder + '/current.valid.txt')\n",
 "\n",
 "    if res_valid['f1'] > best_f1:\n",
 "        rnn.save(folder)\n",
 "        best_f1 = res_valid['f1']\n",
 "        if s['verbose']:\n",
 "            print('NEW BEST: epoch', e, 'valid F1', res_valid['f1'], 'best test F1', res_test['f1'], ' '*20)\n",
 "        s['vf1'], s['vp'], s['vr'] = res_valid['f1'], res_valid['p'], res_valid['r']\n",
 "        s['tf1'], s['tp'], s['tr'] = res_test['f1'], res_test['p'], res_test['r']\n",
 "        s['be'] = e\n",
 "        # keep the files of the best epoch; os.replace overwrites an older best file and\n",
 "        # raises OSError on failure instead of ignoring the exit status of an external 'mv'\n",
 "        os.replace(folder + '/current.test.txt', folder + '/best.test.txt')\n",
 "        os.replace(folder + '/current.valid.txt', folder + '/best.valid.txt')\n",
 "    else:\n",
 "        print('')\n",
 "\n",
 "    # learning rate decay if no improvement in 10 epochs\n",
 "    if s['decay'] and abs(s['be']-s['ce']) >= 10: s['clr'] *= 0.5\n",
 "    if s['clr'] < 1e-5: break\n",
 "\n",
 "# report the epoch that achieved the best validation F1 (s['be']), not the last epoch that ran\n",
 "print('BEST RESULT: epoch', s['be'], 'valid F1', s['vf1'], 'best test F1', s['tf1'], 'with the model', folder)"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.4.0" } }, "nbformat": 4, "nbformat_minor": 0 }